#  -*- coding: utf-8 -*-
#  Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  #
#  http://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

from typing import List, Optional, Dict, Literal
from collections import defaultdict
import re
import torch
from torch import nn

from transformers.cache_utils import DynamicCache

from msmodelslim.utils.logging import get_logger, logger_setter
from msmodelslim.core.QAL.qregistry import QABCRegistry
from msmodelslim.core.base.protocol import BatchProcessRequest
from msmodelslim.quant.processor.base import AutoSessionProcessor, AutoProcessorConfig
from msmodelslim.quant.quantizer.attention import DynamicCacheQuantizer
from msmodelslim.utils.config_map import ConfigSet
from msmodelslim.utils.hook_utils import add_before_hook, add_after_hook, restore_target
from msmodelslim.quant.ir import FakeQuantDynamicCache
from msmodelslim.quant.quantizer.base import QConfig
from msmodelslim.core.QAL.qbase import QScope


HOOK_TARGET = (DynamicCache, 'update')
CACHE_INPUT_NAME = ("key_states", "value_states")
LAYER_IDX_NAME = "layer_idx"
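# HOOK_TARGET points at transformers' DynamicCache.update; CACHE_INPUT_NAME and
# LAYER_IDX_NAME mirror its keyword-argument names (key_states, value_states,
# layer_idx) so the hooks below can read the cached tensors and the layer index
# directly from the call's kwargs.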


class DynamicCacheProcessorConfig(AutoProcessorConfig):
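    """
    Configuration for quantizing the KV cache held in transformers' DynamicCache.

    Illustrative sketch (the QConfig contents and the include/exclude patterns
    below are placeholders, not defaults):

        config = DynamicCacheProcessorConfig(
            qconfig=QConfig(...),        # per-channel quantization settings
            include=["*self_attn*"],     # patterns of attention modules to quantize
            exclude=["*layers.0*"],      # patterns to skip
        )
    """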
    type: Literal['dynamic_cache'] = "dynamic_cache"
    qconfig: QConfig
    include: List[str] = []
    exclude: List[str] = []


def _warning_unmatched_pattern(name: str, config_set: ConfigSet) -> None:
    unmatched_keys = config_set.unmatched_keys()
    unmatched_keys = list(filter(lambda x: x != "*", unmatched_keys))
    if unmatched_keys:
        get_logger().warning(
            f"These {name} patterns did not match any module; please make sure this is expected: {unmatched_keys}")


def _get_module_by_name(model: nn.Module, submodule_key: str) -> nn.Module:
    """根据名称获取模块"""
    module_tokens = submodule_key.split('.')
    cur_mod = model
    for s in module_tokens:
        cur_mod = getattr(cur_mod, s)
    return cur_mod


def _detect_attention_layers(model: torch.nn.Module) -> Dict[int, str]:
    """
    Detect all attention layer prefixes in the model.
    Identifies attention modules by checking whether the class name contains 'attention' or 'attn'.
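
    Returns a mapping from layer index to the attention module's qualified name,
    e.g. (illustrative, assuming LLaMA-style naming "model.layers.<i>.self_attn"):
        {0: "model.layers.0.self_attn", 1: "model.layers.1.self_attn", ...}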
    """
    attention_layers = {}
    
    for name, module in model.named_modules():
        class_name = module.__class__.__name__.lower()
        if 'attention' in class_name or 'attn' in class_name:
            # Extract layer index from module name
            numbers = re.findall(r'\.(\d+)\.', name)
            if numbers:
                layer_idx = int(numbers[0])
                attention_layers[layer_idx] = name
    
    return attention_layers


def _get_first_layer(model: torch.nn.Module):
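    """Return the decoder layer that owns the first detected attention module (layer index 0).

    Its forward is hooked so that the DynamicCache instance flowing through the model
    can be discovered at runtime.
    """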
    attention_layers = _detect_attention_layers(model)
    if 0 not in attention_layers:
        raise ValueError("Could not detect an attention layer with index 0 in the model.")
    layer_name = '.'.join(attention_layers[0].split('.')[:-1])
    mod = _get_module_by_name(model, layer_name)
    return mod


@QABCRegistry.register(dispatch_key=DynamicCacheProcessorConfig, abc_class=AutoSessionProcessor)
@logger_setter("msmodelslim.quant.processor.quant.attention")
class DynamicCacheQuantProcessor(AutoSessionProcessor):
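    """
    Session processor that calibrates and fake-quantizes the KV cache stored in
    transformers' DynamicCache.

    Lifecycle, as implemented below:
        1. pre_run:     create per-layer DynamicCacheQuantizer observers and hook the
                        first decoder layer's forward to locate the DynamicCache object.
        2. calibration: _cache_update_hook feeds every cache update into the observers.
        3. postprocess: deploy the calibrated quantizers as FakeQuantDynamicCache
                        modules on their attention modules.
        4. post_run:    drop the calibration hooks and install _fake_quant_update so
                        subsequent forwards return fake-quantized key/value states.
    """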
    def __init__(
            self,
            model: nn.Module,
            config: DynamicCacheProcessorConfig,
            adapter: Optional[object] = None,
    ):
        super().__init__(model)
        self.config = config
        if self.config.qconfig.scope != QScope.PER_CHANNEL:
            raise ValueError("DynamicCacheQuantProcessor only supports per_channel quantization!")

        self.include = ConfigSet(config.include) if config.include else ConfigSet(["*"])
        self.exclude = ConfigSet(config.exclude) if config.exclude else ConfigSet([])

        self.input_name_map = {i: key for i, key in enumerate(CACHE_INPUT_NAME)}
        self.input_layer_idx_name = LAYER_IDX_NAME

        self.cache_quantizers: Dict[int, Dict[str, DynamicCacheQuantizer]] = defaultdict(
            lambda: {quant_name: None for quant_name in self.input_name_map.values()}
        )
        # Track quantizer ready status
        self.quantizer_ready: Dict[int, Dict[str, bool]] = defaultdict(
            lambda: {quant_name: False for quant_name in self.input_name_map.values()}
        )
        self.fake_kvcache_quantizers: Dict[int, Dict[str, FakeQuantDynamicCache]] = defaultdict(
            lambda: {quant_name: None for quant_name in self.input_name_map.values()}
        )
        # Add the trigger hook to the module that first uses the KV cache
        self.first_layer = _get_first_layer(self.model)
        self.trigger_hook_target = (self.first_layer, 'forward')
        self.cache_target = HOOK_TARGET
        self._trigger_hook_installed = False
        self._use_global_hook = False
        self._attention_layers_map = _detect_attention_layers(self.model)
        # Hook registry to avoid duplicate hook installation using cache IDs
        self._installed_cache_ids = set()

    def is_data_free(self) -> bool:
        return False

    def need_kv_cache(self) -> bool:
        return True

    def support_distributed(self) -> bool:
        return False

    def pre_run(self) -> None:
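        """Create per-layer cache quantizer observers and arm the calibration trigger hook."""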
        attention_layers = _detect_attention_layers(self.model)
        for layer_idx, _ in attention_layers.items():
            self._create_quantizer(layer_idx)

        # add fake quantize hook
        add_before_hook(self.trigger_hook_target, self._add_quantizer_hook)
        self._trigger_hook_installed = True

    def postprocess(self, _: BatchProcessRequest) -> None:
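        """Warn about unmatched include/exclude patterns, then deploy the calibrated quantizers onto their attention modules."""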
        _warning_unmatched_pattern("include", self.include)
        _warning_unmatched_pattern("exclude", self.exclude)
        attention_layers = _detect_attention_layers(self.model)
        for layer_idx, attention_prefix in attention_layers.items():
            mod = _get_module_by_name(self.model, attention_prefix)
            self._deploy_quantizer(mod, layer_idx)
    
    def post_run(self) -> None:
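        """Remove the calibration hooks and install the fake-quantization hook for subsequent forwards."""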
        # remove global hook if used
        if self._trigger_hook_installed:
            if self._use_global_hook:
                # Remove global hook
                restore_target(self.cache_target)
                self._use_global_hook = False
            # Remove trigger hook
            restore_target(self.trigger_hook_target)
            self._trigger_hook_installed = False
        
        # Clear hook registry
        self._installed_cache_ids.clear()

        # Install execution hook for fake quantization
        add_before_hook(self.trigger_hook_target, self._add_fake_quant_hook)
        self._trigger_hook_installed = True

    def _create_quantizer(self, layer_idx: int):
        for _, target_name in self.input_name_map.items():
            if self.cache_quantizers[layer_idx][target_name] is None:
                self.cache_quantizers[layer_idx][target_name] = DynamicCacheQuantizer(self.config.qconfig)

    def _deploy_quantizer(self, mod: nn.Module, layer_idx: int):
        # Only deploy quantizers that have observed calibration data
        for _, target_name in self.input_name_map.items():
            if self.quantizer_ready[layer_idx][target_name] and self.cache_quantizers[layer_idx][target_name]:
                mod.add_module(f'{target_name}_quantizer', self.cache_quantizers[layer_idx][target_name].deploy())
                self.fake_kvcache_quantizers[layer_idx][target_name] = getattr(mod, f'{target_name}_quantizer')

    def _add_quantizer_hook(self, _, kwargs):
        for _, value in kwargs.items():
            if isinstance(value, self.cache_target[0]):
                # Check if hook already installed using cache ID
                cache_id = id(value)
                if cache_id in self._installed_cache_ids:
                    return
                target = (value, self.cache_target[1])
                add_after_hook(target, self._cache_update_hook)
                self._installed_cache_ids.add(cache_id)
                return
        if not self._use_global_hook:
            get_logger().warning(f"No {self.cache_target[0].__name__} found in the model forward function arguments"
                        f"try to hook on {self.cache_target[0].__name__}.{self.cache_target[1]}, "
                        "this may influence other model's inference!")
            add_after_hook(self.cache_target, self._cache_update_hook)
            self._use_global_hook = True

    def _add_fake_quant_hook(self, _, kwargs):
        for _, value in kwargs.items():
            if isinstance(value, self.cache_target[0]):
                # Check if hook already installed using cache ID
                cache_id = id(value)
                if cache_id in self._installed_cache_ids:
                    return
                target = (value, self.cache_target[1])
                add_after_hook(target, self._fake_quant_update)
                self._installed_cache_ids.add(cache_id)
                return
        if not self._use_global_hook:
            get_logger().warning(f"No {self.cache_target[0].__name__} found in the model forward function arguments"
                        f"try to hook on {self.cache_target[0].__name__}.{self.cache_target[1]}, "
                        "this may influence other model's inference!")
            add_after_hook(self.cache_target, self._fake_quant_update)
            self._use_global_hook = True

    def _cache_update_hook(self, _, kwargs, result):
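        """After-hook on DynamicCache.update during calibration: feed the incoming
        key/value states into the per-layer observers and return the original result unchanged."""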
        layer_idx = kwargs.get(self.input_layer_idx_name)
        layer_name = self._attention_layers_map[layer_idx]

        # Skip layers that are not included or are explicitly excluded
        if layer_name not in self.include or layer_name in self.exclude:
            return result

        # Feed the incoming key/value states into the per-layer observers
        for _, target_name in self.input_name_map.items():
            states = kwargs.get(target_name)
            quantizer = self.cache_quantizers[layer_idx][target_name]
            if quantizer is not None:
                quantizer(states)
                self.quantizer_ready[layer_idx][target_name] = True

        return result

    def _fake_quant_update(self, _, kwargs, result):
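        """After-hook on DynamicCache.update after calibration: run the deployed
        FakeQuantDynamicCache modules on the updated key/value states and return the
        fake-quantized result."""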
        layer_idx = kwargs.get(self.input_layer_idx_name)
        layer_name = self._attention_layers_map[layer_idx]

        # Skip layers that are not included or are explicitly excluded
        if layer_name not in self.include or layer_name in self.exclude:
            return result

        res = []
        for idx, target_name in self.input_name_map.items():
            states = result[idx] if isinstance(result, tuple) else result
            quantizer = self.fake_kvcache_quantizers[layer_idx][target_name]
            if quantizer is not None:
                states = quantizer(states)
            res.append(states)
        return tuple(res) if isinstance(result, tuple) else res[0]

